from nebula3.gclient.net import ConnectionPool
from nebula3.Config import Config
import logging
import socket
from app.base.logger import setup_logger

# 使用新的日志配置
logger = setup_logger("nebula_client")

class NebulaClient:
    """
    NebulaGraph客户端封装类，提供连接和查询功能
    """
    def __init__(self):
        self.connection_pool = None
        self.space_name = None
        self.connected = False
        self.error_message = None
    
    def connect(self, ip, port, user, password, space_name):
        """
        连接NebulaGraph数据库
        
        Args:
            ip: 数据库IP地址
            port: 数据库端口
            user: 用户名
            password: 密码
            space_name: 图空间名
        
        Returns:
            bool: 连接是否成功
        """
        try:
            # 验证IP地址格式
            try:
                # 尝试验证IP地址的有效性
                socket.inet_aton(ip)
                logger.info(f"正在连接到 NebulaGraph 数据库: {ip}:{port}, 空间: {space_name}")
            except socket.error:
                # 如果不是有效的IP地址，可能是主机名
                logger.warning(f"提供的地址 {ip} 不是有效的IP地址格式，尝试作为主机名处理")
                try:
                    resolved_ip = socket.gethostbyname(ip)
                    logger.info(f"主机名 {ip} 解析为IP: {resolved_ip}")
                    ip = resolved_ip
                except socket.gaierror as e:
                    logger.error(f"无法解析主机名 {ip}: {str(e)}")
                    self.error_message = f"无法解析主机名: {ip}"
                    return False
            
            # 确保端口是整数
            try:
                port = int(port)
                if port <= 0 or port > 65535:
                    logger.error(f"无效的端口号: {port}，端口范围应为1-65535")
                    self.error_message = f"端口号无效: {port}"
                    return False
            except (ValueError, TypeError):
                logger.error(f"端口号必须是整数: {port}")
                self.error_message = "端口号必须是整数"
                return False
            
            config = Config()
            config.max_connection_pool_size = 10
            
            # 创建连接池
            self.connection_pool = ConnectionPool()
            
            # 按照Nebula Python客户端的要求构建地址列表
            # 格式应该是: [("ip", port)]
            address_list = [(ip, port)]
            logger.debug(f"准备初始化连接池，地址列表: {address_list}")
            
            # 初始化连接池
            ok = self.connection_pool.init(address_list, config)
            if not ok:
                logger.error(f"连接池初始化失败: {address_list}")
                self.error_message = "连接池初始化失败"
                return False
            
            # 创建客户端连接
            logger.debug(f"正在创建客户端会话: user={user}")
            client = self.connection_pool.get_session(user, password)
            if not client:
                logger.error(f"创建会话失败: user={user}")
                self.error_message = "创建会话失败，请检查用户名和密码"
                self.connection_pool.close()
                self.connection_pool = None
                return False
                
            # 尝试直接使用空间，而不是先列出所有空间
            # Nebula Graph可能对SHOW SPACES有权限限制
            logger.debug(f"直接尝试切换到图空间: {space_name}")
            resp = client.execute(f"USE {space_name}")
            
            # 检查响应
            if not resp.is_succeeded():
                # 这里使用is_succeeded()方法检查执行是否成功
                error_code = resp.error_code()
                error_msg = resp.error_msg()
                logger.error(f"使用图空间失败: 错误码={error_code}, 错误信息={error_msg}, 空间名={space_name}")
                
                # 尝试获取可用的空间列表
                try:
                    spaces_resp = client.execute("SHOW SPACES")
                    available_spaces = []
                    
                    if spaces_resp.is_succeeded() and spaces_resp.rows():
                        for record in spaces_resp.rows():
                            for col_name, value in zip(spaces_resp.keys(), record.values):
                                if col_name == "Name":
                                    available_spaces.append(value.as_string())
                        
                        self.error_message = f"图空间 '{space_name}' 不可用。可用的图空间: {', '.join(available_spaces)}"
                    else:
                        self.error_message = f"无法使用图空间 '{space_name}'。且无法获取可用空间列表，请检查权限。"
                except Exception as e:
                    logger.exception(f"获取空间列表失败: {str(e)}")
                    self.error_message = f"无法使用图空间 '{space_name}'：{error_msg}"
                
                # 关闭连接
                self.connection_pool.close()
                self.connection_pool = None
                client.release()
                return False
            
            # 保存连接信息
            self.ip = ip
            self.port = port
            self.username = user
            self.password = password
            self.space_name = space_name
            self.connected = True
            self.error_message = None
            
            client.release()
            logger.info(f"成功连接到 NebulaGraph 数据库: {ip}:{port}, 空间: {space_name}")
            return True
            
        except Exception as e:
            logger.exception(f"连接NebulaGraph失败: {str(e)}")
            self.error_message = f"连接异常: {str(e)}"
            if self.connection_pool:
                self.connection_pool.close()
                self.connection_pool = None
            return False
    
    def disconnect(self):
        """断开连接"""
        if self.connection_pool:
            logger.info("断开NebulaGraph数据库连接")
            self.connection_pool.close()
            self.connection_pool = None
            self.connected = False
    
    def execute_query(self, query):
        """
        执行查询
        
        Args:
            query: 查询语句
        
        Returns:
            查询结果或None
        """
        if not self.connected or not self.connection_pool:
            logger.error("未连接到数据库，无法执行查询")
            return None
        
        try:
            logger.debug(f"执行查询: {query[:200]}...")
            # 使用当前会话的用户名和密码
            username = getattr(self, 'username', 'root')
            password = getattr(self, 'password', 'nebula')
            client = self.connection_pool.get_session(username, password)
            if not client:
                logger.error("获取会话失败")
                return None
            
            # 确保已选择图空间
            space_name = getattr(self, 'space_name', None)
            if space_name:
                logger.debug(f"确保选择图空间: {space_name}")
                use_resp = client.execute(f"USE {space_name}")
                if not use_resp.is_succeeded():
                    error_code = use_resp.error_code()
                    error_msg = use_resp.error_msg()
                    logger.error(f"选择图空间失败: 错误码={error_code}, 错误信息={error_msg}, 空间={space_name}")
                    client.release()
                    return None
            else:
                logger.error("没有设置图空间名称")
                client.release()
                return None
            
            # 执行实际查询
            resp = client.execute(query)
            client.release()
            
            # 使用is_succeeded()方法检查执行是否成功
            if not resp.is_succeeded():
                error_code = resp.error_code()
                error_msg = resp.error_msg()
                logger.error(f"查询执行失败: 错误码={error_code}, 错误信息={error_msg}, 查询: {query[:200]}...")
                return None
                
            logger.debug(f"查询执行成功，返回 {len(resp.rows()) if resp.rows() else 0} 行")
            return resp
            
        except Exception as e:
            logger.exception(f"执行查询失败: {str(e)}, 查询: {query[:200]}...")
            return None
    
    def search_method_by_name(self, method_name):
        """
        根据方法名查询方法节点
        
        Args:
            method_name: 方法全名或部分名称
        
        Returns:
            节点列表
        """
        query = f"""
        MATCH (v:function)
        WHERE v.function.full_name STARTS WITH "{method_name}"
        RETURN id(v) as node_id, v.function.name, v.function.full_name, v.function.type, v.function.visibility LIMIT 100
        """
        
        result = self.execute_query(query)
        if not result:
            return []
        
        # 处理结果集
        nodes = []
        for record in result.rows():
            # 根据提供的示例代码，直接从 values 属性获取值
            values = record.values
            # 跳过无效数据
            if len(values) < 3 or not values[2].value:
                continue
                
            # 将查询结果转换为字典
            try:
                node_id = values[0].value.decode('utf-8') if isinstance(values[0].value, bytes) else values[0].value
                node = {
                    "node_id": node_id,  # 使用node_id字段名
                    "id": node_id,       # 保持向后兼容
                    "name": values[1].value.decode('utf-8') if isinstance(values[1].value, bytes) else values[1].value,
                    "full_name": values[2].value.decode('utf-8') if isinstance(values[2].value, bytes) else values[2].value,
                    "type": values[3].value.decode('utf-8') if isinstance(values[3].value, bytes) else values[3].value,
                    "visibility": values[4].value.decode('utf-8') if isinstance(values[4].value, bytes) else values[4].value
                }
                nodes.append(node)
            except Exception as e:
                logger.error(f"处理节点数据时出错: {str(e)}")
                continue
            
        return nodes
    
    def get_upstream_methods(self, method_full_name, path_depth=1):
        """
        获取指定方法的上游方法
        
        Args:
            method_full_name: 方法全名
            path_depth: 路径深度，默认为1，-1表示无限深度
        
        Returns:
            上游方法节点列表及其关系
        """
        # 设置路径深度
        if path_depth <= 0 and path_depth != -1:
            path_depth = 1
        
        # 对于无限深度，设置较大但有限的值
        if path_depth == -1:
            path_depth = 10
        
        # 首先获取起始节点的VID
        start_vid = None
        vid_query = f"""
        MATCH (v:function)
        WHERE v.function.full_name == "{method_full_name}"
        RETURN id(v) as vid
        """
        
        result = self.execute_query(vid_query)
        if result and result.rows() and len(result.rows()) > 0:
            vid_value = result.rows()[0].values[0]
            start_vid = vid_value.value.decode('utf-8') if isinstance(vid_value.value, bytes) else vid_value.value
        
        if not start_vid:
            logger.error(f"无法找到方法的VID: {method_full_name}")
            return {"nodes": [], "edges": []}
        
        # 构建GET SUBGRAPH查询 - 注意这里使用IN方向，表示上游节点
        logger.info(f"使用GET SUBGRAPH查询上游树: 方法={method_full_name}, 深度={path_depth}")
        subgraph_query = f"""
        GET SUBGRAPH WITH PROP {path_depth} STEPS FROM "{start_vid}" IN calls,out_calls,implemented_by,overridden_by,super_calls,interface_calls,subtype_calls,injection_calls
        YIELD VERTICES AS nodes, EDGES AS relationships
        """
        
        try:
            # 执行GET SUBGRAPH查询
            result = self.execute_query(subgraph_query)
            
            if not result or not result.rows():
                logger.warning(f"GET SUBGRAPH查询没有返回结果: {method_full_name}")
                return {"nodes": [], "edges": []}
            
            # 处理结果
            all_nodes = {}
            all_edges = []
            
            # 处理GET SUBGRAPH返回的结果
            for row_idx, row in enumerate(result.rows()):
                try:
                    # 获取节点集合和边集合
                    vertices_value = row.values[0]  # 第一个值是节点列表
                    edges_value = row.values[1]     # 第二个值是边列表
                    
                    # 处理节点数据
                    if vertices_value and hasattr(vertices_value, 'value'):
                        try:
                            # 获取节点NList
                            vertices_list = vertices_value.value.values
                            
                            # 检查列表长度
                            if hasattr(vertices_list, '__len__'):
                                list_length = len(vertices_list)
                                logger.debug(f"顶点列表长度: {list_length}")
                                
                                # 遍历顶点列表
                                for i in range(list_length):
                                    try:
                                        vertex_value = vertices_list[i]
                                        if not vertex_value or not hasattr(vertex_value, 'value'):
                                            continue
                                            
                                        # 获取顶点对象
                                        vertex = vertex_value.value
                                        
                                        # 检查是否是有效的顶点
                                        if not vertex or not hasattr(vertex, 'vid'):
                                            continue
                                            
                                        # 获取顶点ID
                                        vid_obj = vertex.vid
                                        if not vid_obj or not hasattr(vid_obj, 'value'):
                                            continue
                                            
                                        vid = vid_obj.value.decode('utf-8') if isinstance(vid_obj.value, bytes) else vid_obj.value
                                        
                                        # 获取顶点标签和属性
                                        props = {}
                                        if hasattr(vertex, 'tags') and vertex.tags:
                                            for tag in vertex.tags:
                                                if hasattr(tag, 'name') and tag.name == b'function':
                                                    tag_props = tag.props if hasattr(tag, 'props') else {}
                                                    props = tag_props
                                                    break
                                        
                                        # 构建节点对象
                                        name = self.get_prop_value(props, b"name", "")
                                        full_name = self.get_prop_value(props, b"full_name", "")
                                        node_type = self.get_prop_value(props, b"type", "")
                                        visibility = self.get_prop_value(props, b"visibility", "")
                                        isLibrary = self.get_prop_value(props, b"is_library", "")
                                        
                                        # 添加节点类型标记
                                        node = {
                                            "id": vid,
                                            "properties": {
                                                "name": name,
                                                "full_name": full_name,
                                                "type": node_type,
                                                "visibility": visibility,
                                                # 标记源节点
                                                "source_node": vid == start_vid,
                                                # 标记DAO节点（如果类名以dao结尾）
                                                "dao_node": "dao" in full_name.lower() if full_name else False,
                                                "is_library": isLibrary
                                            }
                                        }
                                        all_nodes[vid] = node
                                    except Exception as e:
                                        logger.error(f"处理单个顶点[{i}]时出错: {str(e)}")
                        except Exception as e:
                            logger.error(f"处理顶点列表时出错: {str(e)}")
                    
                    # 处理边数据
                    if edges_value and hasattr(edges_value, 'value'):
                        try:
                            # 获取边NList
                            edges_list = edges_value.value.values
                            
                            # 检查列表长度
                            if hasattr(edges_list, '__len__'):
                                list_length = len(edges_list)
                                logger.debug(f"边列表长度: {list_length}")
                                
                                # 遍历边列表
                                for i in range(list_length):
                                    try:
                                        edge_value = edges_list[i]
                                        if not edge_value or not hasattr(edge_value, 'value'):
                                            continue
                                            
                                        # 获取边对象
                                        edge = edge_value.value
                                        
                                        # 检查是否是有效的边
                                        if not edge or not hasattr(edge, 'src') or not hasattr(edge, 'dst'):
                                            continue
                                            
                                        # 获取源顶点ID
                                        src_obj = edge.src
                                        if not src_obj or not hasattr(src_obj, 'value'):
                                            continue
                                            
                                        src_id = src_obj.value.decode('utf-8') if isinstance(src_obj.value, bytes) else src_obj.value
                                        
                                        # 获取目标顶点ID
                                        dst_obj = edge.dst
                                        if not dst_obj or not hasattr(dst_obj, 'value'):
                                            continue
                                            
                                        dst_id = dst_obj.value.decode('utf-8') if isinstance(dst_obj.value, bytes) else dst_obj.value
                                        
                                        # 获取边类型
                                        edge_type = "calls"
                                        if hasattr(edge, 'name'):
                                            edge_type = edge.name.decode('utf-8') if isinstance(edge.name, bytes) else edge.name
                                        
                                        # 构建边对象
                                        edge_obj = {
                                            "source": src_id,
                                            "target": dst_id,
                                            "properties": {
                                                "type": edge_type
                                            }
                                        }
                                        all_edges.append(edge_obj)
                                    except Exception as e:
                                        logger.error(f"处理单条边[{i}]时出错: {str(e)}")
                        except Exception as e:
                            logger.error(f"处理边列表时出错: {str(e)}")
                
                except Exception as e:
                    logger.error(f"处理GET SUBGRAPH行数据时出错: {str(e)}")
                    continue
            
            logger.info(f"GET SUBGRAPH查询完成: 收集了 {len(all_nodes)} 个节点, {len(all_edges)} 条边")
            
            return {
                "nodes": list(all_nodes.values()),
                "edges": all_edges
            }
            
        except Exception as e:
            logger.error(f"GET SUBGRAPH查询或解析出错: {str(e)}")
            return {"nodes": [], "edges": []}
    
    def get_downstream_methods(self, method_full_name, path_depth=1):
        """
        获取指定方法的下游方法
        
        Args:
            method_full_name: 方法全名
            path_depth: 路径深度，默认为1，-1表示无限深度
        
        Returns:
            下游方法节点列表及其关系
        """
        # 设置路径深度
        if path_depth <= 0 and path_depth != -1:
            path_depth = 1
        
        # 对于无限深度，设置较大但有限的值
        if path_depth == -1:
            path_depth = 10
        
        # 首先获取起始节点的VID
        start_vid = None
        vid_query = f"""
        MATCH (v:function)
        WHERE v.function.full_name == "{method_full_name}"
        RETURN id(v) as vid
        """
        
        result = self.execute_query(vid_query)
        if result and result.rows() and len(result.rows()) > 0:
            vid_value = result.rows()[0].values[0]
            start_vid = vid_value.value.decode('utf-8') if isinstance(vid_value.value, bytes) else vid_value.value
        
        if not start_vid:
            logger.error(f"无法找到方法的VID: {method_full_name}")
            return {"nodes": [], "edges": []}
        
        # 构建GET SUBGRAPH查询
        logger.info(f"使用GET SUBGRAPH查询下游树: 方法={method_full_name}, 深度={path_depth}")
        subgraph_query = f"""
        GET SUBGRAPH WITH PROP {path_depth} STEPS FROM "{start_vid}" OUT calls,out_calls,implemented_by,overridden_by,super_calls,interface_calls,subtype_calls,injection_calls
        YIELD VERTICES AS nodes, EDGES AS relationships
        """
        
        try:
            # 执行GET SUBGRAPH查询
            result = self.execute_query(subgraph_query)
            
            if not result or not result.rows():
                logger.warning(f"GET SUBGRAPH查询没有返回结果: {method_full_name}")
                return {"nodes": [], "edges": []}
            
            # 处理结果
            all_nodes = {}
            all_edges = []
            
            # 处理GET SUBGRAPH返回的结果
            for row_idx, row in enumerate(result.rows()):
                try:
                    # 获取节点集合和边集合
                    vertices_value = row.values[0]  # 第一个值是节点列表
                    edges_value = row.values[1]     # 第二个值是边列表
                    
                    # 处理节点数据
                    if vertices_value and hasattr(vertices_value, 'value'):
                        try:
                            # 获取节点NList
                            vertices_list = vertices_value.value.values
                            
                            # 检查列表长度
                            if hasattr(vertices_list, '__len__'):
                                list_length = len(vertices_list)
                                logger.debug(f"顶点列表长度: {list_length}")
                                
                                # 遍历顶点列表
                                for i in range(list_length):
                                    try:
                                        vertex_value = vertices_list[i]
                                        if not vertex_value or not hasattr(vertex_value, 'value'):
                                            continue
                                            
                                        # 获取顶点对象
                                        vertex = vertex_value.value
                                        
                                        # 检查是否是有效的顶点
                                        if not vertex or not hasattr(vertex, 'vid'):
                                            continue
                                            
                                        # 获取顶点ID
                                        vid_obj = vertex.vid
                                        if not vid_obj or not hasattr(vid_obj, 'value'):
                                            continue
                                            
                                        vid = vid_obj.value.decode('utf-8') if isinstance(vid_obj.value, bytes) else vid_obj.value
                                        
                                        # 获取顶点标签和属性
                                        props = {}
                                        if hasattr(vertex, 'tags') and vertex.tags:
                                            for tag in vertex.tags:
                                                if hasattr(tag, 'name') and tag.name == b'function':
                                                    tag_props = tag.props if hasattr(tag, 'props') else {}
                                                    props = tag_props
                                                    break
                                        
                                        # 构建节点对象
                                        name = self.get_prop_value(props, b"name", "")
                                        full_name = self.get_prop_value(props, b"full_name", "")
                                        node_type = self.get_prop_value(props, b"type", "")
                                        visibility = self.get_prop_value(props, b"visibility", "")
                                        isLibrary = self.get_prop_value(props, b"is_library", "")
                                        
                                        # 添加节点类型标记
                                        node = {
                                            "id": vid,
                                            "properties": {
                                                "name": name,
                                                "full_name": full_name,
                                                "type": node_type,
                                                "visibility": visibility,
                                                "is_library": isLibrary,
                                                # 标记源节点
                                                "source_node": vid == start_vid,
                                                # 标记DAO节点（如果类名以dao结尾）
                                                "dao_node": "dao" in full_name.lower() if full_name else False
                                            }
                                        }
                                        all_nodes[vid] = node
                                    except Exception as e:
                                        logger.error(f"处理单个顶点[{i}]时出错: {str(e)}")
                        except Exception as e:
                            logger.error(f"处理顶点列表时出错: {str(e)}")
                    
                    # 处理边数据
                    if edges_value and hasattr(edges_value, 'value'):
                        try:
                            # 获取边NList
                            edges_list = edges_value.value.values
                            
                            # 检查列表长度
                            if hasattr(edges_list, '__len__'):
                                list_length = len(edges_list)
                                logger.debug(f"边列表长度: {list_length}")
                                
                                # 遍历边列表
                                for i in range(list_length):
                                    try:
                                        edge_value = edges_list[i]
                                        if not edge_value or not hasattr(edge_value, 'value'):
                                            continue
                                            
                                        # 获取边对象
                                        edge = edge_value.value
                                        
                                        # 检查是否是有效的边
                                        if not edge or not hasattr(edge, 'src') or not hasattr(edge, 'dst'):
                                            continue
                                            
                                        # 获取源顶点ID
                                        src_obj = edge.src
                                        if not src_obj or not hasattr(src_obj, 'value'):
                                            continue
                                            
                                        src_id = src_obj.value.decode('utf-8') if isinstance(src_obj.value, bytes) else src_obj.value
                                        
                                        # 获取目标顶点ID
                                        dst_obj = edge.dst
                                        if not dst_obj or not hasattr(dst_obj, 'value'):
                                            continue
                                            
                                        dst_id = dst_obj.value.decode('utf-8') if isinstance(dst_obj.value, bytes) else dst_obj.value
                                        
                                        # 获取边类型
                                        edge_type = "calls"
                                        if hasattr(edge, 'name'):
                                            edge_type = edge.name.decode('utf-8') if isinstance(edge.name, bytes) else edge.name
                                        
                                        # 构建边对象
                                        edge_obj = {
                                            "source": src_id,
                                            "target": dst_id,
                                            "properties": {
                                                "type": edge_type
                                            }
                                        }
                                        all_edges.append(edge_obj)
                                    except Exception as e:
                                        logger.error(f"处理单条边[{i}]时出错: {str(e)}")
                        except Exception as e:
                            logger.error(f"处理边列表时出错: {str(e)}")
                
                except Exception as e:
                    logger.error(f"处理GET SUBGRAPH行数据时出错: {str(e)}")
                    continue
            
            logger.info(f"GET SUBGRAPH查询完成: 收集了 {len(all_nodes)} 个节点, {len(all_edges)} 条边")
            
            return {
                "nodes": list(all_nodes.values()),
                "edges": all_edges
            }
            
        except Exception as e:
            logger.error(f"GET SUBGRAPH查询或解析出错: {str(e)}")
            return {"nodes": [], "edges": []}
    
    # 获取属性值的辅助函数
    def get_prop_value(self, prop_dict, key, default_value=""):
        if key not in prop_dict:
            return default_value
        
        value_obj = prop_dict[key]
        val = value_obj.value
        return val.decode('utf-8') if isinstance(val, bytes) else val
    
    def get_error_message(self):
        """获取最后一次错误信息"""
        return getattr(self, 'error_message', "未知错误") 